{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import gym\n",
    "import gym_custom_envs\n",
    "import time\n",
    "import scipy.linalg as alg\n",
    "from gym import wrappers\n",
    "# Continuous-action cart-pole environment; the id is presumably registered by the gym_custom_envs import above (confirm).\n",
    "env = gym.make(\"CartPoleContinuous-v0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Physical constants read from the environment's attributes\n",
    "# (cart mass, pole mass, pole length, gravity -- going by the attribute names).\n",
    "mc, mp = env.masscart, env.masspole\n",
    "l, g = env.length, env.gravity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "M = \n",
      "[[1.1   0.05 ]\n",
      " [0.05  0.025]] \n",
      "\n",
      "iM = \n",
      "[[ 1. -2.]\n",
      " [-2. 44.]] \n",
      "\n",
      "dtau_g = \n",
      "[[0.   0.  ]\n",
      " [0.   0.49]] \n",
      "\n",
      "b\n",
      "[[1. 0.]\n",
      " [0. 0.]] \n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Terms of the cart-pole equations of motion used for the linearisation.\n",
    "# C(q, q_dot) is left out: it is zero because q_dot is zero.\n",
    "\n",
    "# Matrix built from the cart/pole masses and the pole length, and its inverse.\n",
    "M = np.matrix([[mc + mp, mp * l],\n",
    "               [mp * l, mp * (l ** 2)]])\n",
    "iM = np.linalg.inv(M)\n",
    "\n",
    "# Gravity-related term; only the second (pole) diagonal entry is non-zero.\n",
    "dtau_g = np.matrix(np.diag([0, mp * g * l]))\n",
    "\n",
    "# Input selection: only the first coordinate receives an input.\n",
    "b = np.matrix(np.diag([1., 0]))\n",
    "\n",
    "for label, value in ((\"M = \", M), (\"iM = \", iM), (\"dtau_g = \", dtau_g), (\"b\", b)):\n",
    "    print(label)\n",
    "    print(value, '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lqr(A, B, Q, R):\n",
    "    \"\"\"Solve the continuous-time, infinite-horizon LQR problem.\n",
    "\n",
    "    For the linear system x_dot = A x + B u this computes the state-feedback\n",
    "    gain K of the control law u = -K x associated with the state cost Q and\n",
    "    the input cost R.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    A : (n, n) array_like\n",
    "        State matrix.\n",
    "    B : (n, m) array_like\n",
    "        Input matrix.\n",
    "    Q : (n, n) array_like\n",
    "        State weighting matrix.\n",
    "    R : (m, m) array_like\n",
    "        Input weighting matrix (must be invertible).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    K : numpy.matrix, shape (m, n)\n",
    "        Feedback gain, K = R^-1 B^T S.\n",
    "    S : numpy.matrix, shape (n, n)\n",
    "        Solution of the continuous algebraic Riccati equation.\n",
    "    \"\"\"\n",
    "    # Work on plain ndarrays so that `@` is always a true matrix product.\n",
    "    # `inv(R)*B.T*S` only multiplied matrices when an operand happened to be\n",
    "    # an np.matrix; with ndarray inputs `*` is element-wise and the gain is wrong.\n",
    "    A, B, Q, R = (np.asarray(mat) for mat in (A, B, Q, R))\n",
    "\n",
    "    S = alg.solve_continuous_are(A, B, Q, R)\n",
    "\n",
    "    # K = R^-1 B^T S, obtained by solving R K = B^T S instead of forming inv(R).\n",
    "    K = alg.solve(R, B.T @ S)\n",
    "\n",
    "    # Keep returning np.matrix objects so existing callers see the same types.\n",
    "    return np.matrix(K), np.matrix(S)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "A = \n",
      "[[ 0.    0.    1.    0.  ]\n",
      " [ 0.    0.    0.    1.  ]\n",
      " [ 0.   -0.98  0.    0.  ]\n",
      " [ 0.   21.56  0.    0.  ]] \n",
      "\n",
      "B = \n",
      "[[ 0.  0.]\n",
      " [ 0.  0.]\n",
      " [ 1.  0.]\n",
      " [-2.  0.]] \n",
      "\n",
      "Q = \n",
      "[[10.  0.  0.  0.]\n",
      " [ 0. 10.  0.  0.]\n",
      " [ 0.  0.  1.  0.]\n",
      " [ 0.  0.  0.  1.]] \n",
      "\n",
      "R = \n",
      "[[1. 0.]\n",
      " [0. 1.]] \n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Block state-space matrices:\n",
    "#     A = [[0, I], [iM.dtau_g, 0]]      B = [[0], [iM.b]]\n",
    "A = np.vstack((\n",
    "    np.hstack((np.zeros((2, 2)), np.eye(2))),\n",
    "    np.hstack((np.dot(iM, dtau_g), np.zeros((2, 2)))),\n",
    "))\n",
    "print(\"A = \")\n",
    "print(A, '\\n')\n",
    "\n",
    "B = np.vstack((np.zeros((2, 2)), np.dot(iM, b)))\n",
    "print(\"B = \")\n",
    "print(B, '\\n')\n",
    "\n",
    "# State weights: the first two states are weighted 10x more than the last two.\n",
    "Q = np.matrix(np.diag([10., 10., 1., 1.]))\n",
    "print(\"Q = \")\n",
    "print(Q, '\\n')\n",
    "\n",
    "# Input weights: identity.\n",
    "R = np.identity(2)\n",
    "print(\"R = \")\n",
    "print(R, '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "K = \n",
      "[[ -3.16227766 -37.08642453  -4.23995135  -8.18258944]\n",
      " [  0.           0.           0.           0.        ]] \n",
      "\n",
      "S = \n",
      "[[ 13.40790344  25.8756198    8.48859373   5.8254357 ]\n",
      " [ 25.8756198  151.6477398   28.86834548  32.977385  ]\n",
      " [  8.48859373  28.86834548   8.74105515   6.49050325]\n",
      " [  5.8254357   32.977385     6.49050325   7.33654635]] \n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Feedback gain K and Riccati solution S for the linearised system.\n",
    "K, S = lqr(A, B, Q, R)\n",
    "\n",
    "for label, value in ((\"K = \", K), (\"S = \", S)):\n",
    "    print(label)\n",
    "    print(value, '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "successfully reached goal  10 times out of  10\n"
     ]
    }
   ],
   "source": [
    "from IPython.display import clear_output\n",
    "\n",
    "# Roll out the LQR controller for `number` episodes and count how many end\n",
    "# with entry 1 of the final observation within one degree of zero.\n",
    "# Every episode is recorded into the \"vid\" directory; force=True lets the\n",
    "# Monitor overwrite recordings from an earlier run.\n",
    "env = gym.wrappers.Monitor(env, \"vid\", video_callable=lambda i: True, force=True)\n",
    "win = 0\n",
    "number = 10\n",
    "try:\n",
    "    for i in range(number):\n",
    "        # Keep the state as a (1, 4) row so that K . s.T is a column of inputs.\n",
    "        s = env.reset().reshape(1, 4)\n",
    "        done = False\n",
    "        while not done:\n",
    "            # Feedback law u = -K s; only the first input channel is sent to the env.\n",
    "            a = -np.dot(K, s.T)[0, 0]\n",
    "            ns, r, done, _ = env.step(a)\n",
    "            env.render()\n",
    "            s = ns.reshape(1, 4)\n",
    "        # NOTE(review): entry 1 is presumably the pole angle (state ordered\n",
    "        # [q, q_dot] like the linearisation) -- confirm against the env.\n",
    "        if abs(s[0, 1]) < np.deg2rad(1):\n",
    "            win += 1\n",
    "    print('successfully reached goal ', win, 'times out of ', number)\n",
    "finally:\n",
    "    # Always release the viewer and finalise the recording, even when an\n",
    "    # episode raises part-way through.\n",
    "    env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
